package face3wap;

import captcharecognition.MulRecordDataLoader;
import facewap.GanModel;
import lombok.Builder;
import org.apache.commons.lang3.ArrayUtils;
import org.datavec.api.io.labels.ParentPathLabelGenerator;
import org.datavec.api.split.FileSplit;
import org.datavec.image.loader.NativeImageLoader;
import org.datavec.image.recordreader.ImageRecordReader;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.graph.PreprocessorVertex;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.modelimport.keras.preprocessors.KerasFlattenRnnPreprocessor;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.PerformanceListener;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.impl.ActivationLReLU;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.linalg.lossfunctions.LossFunctions;

import javax.swing.*;
import java.awt.*;
import java.awt.image.BufferedImage;
import java.io.File;
import java.util.Arrays;

/**
 * GAN training experiment on 64x64x3 face images: a generator, a discriminator,
 * and a combined "gan" network (generator + frozen discriminator) trained in the
 * classic alternating scheme, with periodic Swing visualization of samples.
 */
public class Img3Gan {
    // Shared Adam optimizer settings used by generator, discriminator and GAN alike.
    private static final double LEARNING_RATE = 0.01;
    private static final double GRADIENT_THRESHOLD = 100.0;
    private static final IUpdater UPDATER = Adam.builder().learningRate(LEARNING_RATE).beta1(0.5).build();

    // Snapshot paths; main() saves all three networks here every 1000 iterations.
    static String genModel = "F:/face/gen.zip";
    static String disModel = "F:/face/dis.zip";
    static String ganModel = "F:/face/gan.zip";
    private static final int seed = 42;

    static int height = 64; // input image height (pixels)
    static int width = 64; // input image width (pixels)
    static int channels = 3; // input image channel count
    static int[] inputShape = new int[] {channels, width, height};
    public static final int batch = 8;

    // Lazily created Swing window shared by visualize()/visualizeSrc().
    private static JFrame frame;
    private static JPanel panel;

    /**
     * Layer stack for the generator: strided convolutions down to 4x4x1024,
     * two dense layers, then a 1x-stride conv expansion.
     * NOTE(review): the reshape/shuffle vertices are commented out, so the
     * dense/conv transitions rely on DL4J's automatically inserted preprocessors —
     * confirm the resulting shapes match the sizes in the comments below.
     */
    private static Layer[] genLayers() {
        return new Layer[] {

                new ConvolutionLayer.Builder().kernelSize(5,5).stride(1, 1).convolutionMode(ConvolutionMode.Same).hasBias(false)
                        .nOut(64).build(),
                //conv_block
                // g2 = 64  / 2  = 32 * 32  * 128  out = in/s
                new ConvolutionLayer.Builder().kernelSize(3,3).stride(2, 2).convolutionMode(ConvolutionMode.Same).hasBias(false)
                        .nOut(128).build(),
                new ActivationLayer.Builder().activation(Activation.RELU).build(),
                // g3 = 32 / 2  = 16 * 16  * 256
                new ConvolutionLayer.Builder().kernelSize(3,3).stride(2, 2).convolutionMode(ConvolutionMode.Same).hasBias(false)
                        .nOut(256).build(),
                new ActivationLayer.Builder().activation(Activation.RELU).build(),
                // g4 = 16 / 2  = 8 * 8  * 512
                new ConvolutionLayer.Builder().kernelSize(3,3).stride(2, 2).convolutionMode(ConvolutionMode.Same).hasBias(false)
                        .nOut(512).build(),
                new ActivationLayer.Builder().activation(Activation.RELU).build(),
                // g5 = 8 / 2  = 4 * 4  * 1024
                 new ConvolutionLayer.Builder().kernelSize(3,3).stride(2,2).convolutionMode(ConvolutionMode.Same).hasBias(false)
                        .nOut(1024).build(),
                new ActivationLayer.Builder().activation(Activation.RELU).build(),
                // flatten to 1-D for the fully-connected layers
                //new ReshapeVertex(batch,  4 * 4  * 1024),
                //new PreprocessorVertex(new KerasFlattenRnnPreprocessor(1024,1)),
                new DenseLayer.Builder().nOut(1024).build(),
                new DenseLayer.Builder().nOut(4 * 4 * 1024) .build(),
                // reshape back to 4-D for the CNN layers
               // new ReshapeVertex(batch,1024,4, 4),
                // g8 = 4 / 1  = 4 * 4  * 512 * 4  = 8 * 8 * 512
                new ConvolutionLayer.Builder().kernelSize(3,3).stride(1, 1).convolutionMode(ConvolutionMode.Same).hasBias(false)
                        .nOut(512 * 4).build(),
                new ActivationLayer.Builder().activation(new ActivationLReLU(0.1)).build(),

                // convert to 8 * 8 * 512
               // new ShuffleVertex(batch,512,8, 8),
                new DenseLayer.Builder().nIn(512 * 4).nOut(512 * 4).build(),
               // new DenseLayer.Builder().nIn(2352).nOut(256).weightInit(WeightInit.NORMAL).build(),
                //new ActivationLayer.Builder(new ActivationLReLU(0.2)).build(),
               /* new DenseLayer.Builder().nIn(128).nOut(256).build(),
                new ActivationLayer.Builder(new ActivationLReLU(0.2)).build(),
                new DenseLayer.Builder().nIn(512).nOut(1024).build(),
                new ActivationLayer.Builder(new ActivationLReLU(0.2)).build(),
                new DenseLayer.Builder().nIn(1024).nOut(2352).activation(Activation.TANH).build()*/
        };
    }

    /**
     * Builds the stand-alone generator configuration from {@link #genLayers()}.
     *
     * <p>Input is a flattened 64x64x3 tensor (see {@code inputShape}).
     * NOTE(review): the previous javadoc described a 10x10-noise to 28x28-grayscale
     * network, which does not match this configuration.
     *
     * @return the generator's MultiLayerConfiguration
     */
    private static MultiLayerConfiguration generator() {
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(seed)
                .updater(UPDATER)
                .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
                .gradientNormalizationThreshold(GRADIENT_THRESHOLD)
                .weightInit(WeightInit.XAVIER)
                .activation(Activation.IDENTITY)
                .list(genLayers())
                // convolutionalFlat(height, width, channels) = (64, 64, 3)
                .setInputType(InputType.convolutionalFlat(inputShape[2], inputShape[1], inputShape[0]))
                .build();

        return conf;
    }

    /**
     * Discriminator layer stack: three dense blocks (LReLU activation + 50% dropout)
     * shrinking 2352 -> 1024 -> 512 -> 256, followed by a sigmoid real/fake output
     * trained with binary cross-entropy (XENT).
     */
    private static Layer[] disLayers() {
        int[] inSizes = {2352, 1024, 512};
        int[] outSizes = {1024, 512, 256};

        Layer[] layers = new Layer[inSizes.length * 3 + 1];
        int idx = 0;
        for (int i = 0; i < inSizes.length; i++) {
            layers[idx++] = new DenseLayer.Builder().nIn(inSizes[i]).nOut(outSizes[i]).build();
            layers[idx++] = new ActivationLayer.Builder(new ActivationLReLU(0.2)).build();
            layers[idx++] = new DropoutLayer.Builder(0.5).build();
        }
        layers[idx] = new OutputLayer.Builder(LossFunctions.LossFunction.XENT)
                .nIn(256).nOut(1).activation(Activation.SIGMOID).build();
        return layers;
    }

    /**
     * Builds the stand-alone discriminator configuration from {@link #disLayers()},
     * sharing the same seed, updater and gradient-normalization settings as the
     * generator and the combined GAN.
     *
     * @return the discriminator's MultiLayerConfiguration
     */
    private static MultiLayerConfiguration discriminator() {
        NeuralNetConfiguration.Builder builder = new NeuralNetConfiguration.Builder()
                .seed(seed)
                .updater(UPDATER)
                .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
                .gradientNormalizationThreshold(GRADIENT_THRESHOLD)
                .weightInit(WeightInit.XAVIER)
                .activation(Activation.IDENTITY);
        return builder.list(disLayers()).build();
    }

    /**
     * Builds the combined GAN configuration: the generator layers followed by the
     * discriminator layers with their trainable (dense/output) layers frozen, so
     * that fitting the GAN updates only the generator while gradients still flow
     * back through the discriminator.
     *
     * @return the combined network's MultiLayerConfiguration
     */
    private static MultiLayerConfiguration gan() {
        Layer[] genLayers = genLayers();

        // Freeze every parameterized discriminator layer; activation/dropout layers
        // carry no parameters and are kept as-is.
        Layer[] frozenDis = disLayers();
        for (int i = 0; i < frozenDis.length; i++) {
            Layer layer = frozenDis[i];
            if (layer instanceof DenseLayer || layer instanceof OutputLayer) {
                frozenDis[i] = new FrozenLayerWithBackprop(layer);
            }
        }

        Layer[] layers = ArrayUtils.addAll(genLayers, frozenDis);

        return new NeuralNetConfiguration.Builder()
                .seed(seed)
                .updater(UPDATER)
                .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
                .gradientNormalizationThreshold(GRADIENT_THRESHOLD)
                .weightInit(WeightInit.XAVIER)
                .activation(Activation.IDENTITY)
                .list(layers)
                .setInputType(InputType.convolutionalFlat(inputShape[2], inputShape[1], inputShape[0]))
                .build();
    }

    /**
     * Trains the GAN with the classic alternating scheme: the discriminator on
     * merged real/fake batches (two updates per iteration), then the combined
     * gan network (frozen discriminator) to improve the generator.
     *
     * <p>Label convention: real = 0, fake = 1, so the generator's training target
     * is 0 ("classified as real"). Every 100 iterations a 3x3 grid of generated
     * samples is shown; every 1000 iterations all three networks are saved.
     *
     * @param args unused
     * @throws Exception on data-loading or model-serialization failure
     */
    public static void main(String... args) throws Exception {
        Nd4j.getMemoryManager().setAutoGcWindow(15 * 1000);

        String inputDataDir = "F:/face/yzm";
        File trainDataFile = new File(inputDataDir + "/train");
        FileSplit trainSplit = new FileSplit(trainDataFile, NativeImageLoader.ALLOWED_FORMATS);

        ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator(); // parent path as the image label
        ImageRecordReader trainRR = new ImageRecordReader(height, width, channels, labelMaker);
        trainRR.initialize(trainSplit);

        DataSetIterator trainData = new RecordReaderDataSetIterator(trainRR, 99, 1, 1);
        DataNormalization scaler = new ImagePreProcessingScaler(0, 1);
        scaler.fit(trainData);
        trainData.setPreProcessor(scaler);

        // TODO: resume training by loading existing snapshots from
        // genModel/disModel/ganModel when present.
        MultiLayerNetwork gen = new MultiLayerNetwork(generator());
        MultiLayerNetwork dis = new MultiLayerNetwork(discriminator());
        MultiLayerNetwork gan = new MultiLayerNetwork(gan());

        gen.init();
        dis.init();
        gan.init();

        System.out.println(gen.summary());
        System.out.println(dis.summary());
        System.out.println(gan.summary());

        // Start all three networks from the GAN's freshly initialized weights.
        copyParams(gen, dis, gan);

        gen.setListeners(new PerformanceListener(10, true));
        dis.setListeners(new PerformanceListener(10, true));
        gan.setListeners(new PerformanceListener(10, true));

        trainData.reset();

        for (int i = 1; i <= 100000; i++) {
            if (!trainData.hasNext()) {
                trainData.reset();
            }
            // Real batch, rescaled in place from [0, 1] to [-1, 1] (muli/subi mutate).
            INDArray real = trainData.next().getFeatures().muli(2).subi(1);
            int batchSize = (int) real.shape()[0];

            // Noise shaped like a real batch; the generator section of the GAN
            // (layers 0..genLayerCount-1) maps it to a fake batch.
            INDArray fakeIn = Nd4j.rand(real.shape());
            INDArray fake = gan.activateSelectedLayers(0, gen.getLayers().length - 1, fakeIn);

            // Discriminator labels: real = 0, fake = 1.
            DataSet realSet = new DataSet(real, Nd4j.zeros(batchSize, 1));
            DataSet fakeSet = new DataSet(fake, Nd4j.ones(batchSize, 1));
            DataSet data = DataSet.merge(Arrays.asList(realSet, fakeSet));

            // Two discriminator updates per generator update.
            dis.fit(data);
            dis.fit(data);

            // Push the freshly trained discriminator weights into the GAN's frozen tail.
            updateGan(gen, dis, gan);

            // Generator update: fresh noise with target label 0 ("real").
            // BUG FIX: was Nd4j.rand(batchSize, real.shape()), which matches no
            // Nd4j.rand overload and double-counts the batch dimension.
            gan.fit(new DataSet(Nd4j.rand(real.shape()), Nd4j.zeros(batchSize, 1)));

            if (i % 100 == 1) {
                System.out.println("Iteration " + i + " Visualizing...");
                // BUG FIX: the last batch of an epoch may hold fewer than 9 rows.
                int sampleCount = Math.min(9, batchSize);
                INDArray[] samples = new INDArray[sampleCount];
                DataSet fakeSet2 = new DataSet(fakeIn, Nd4j.ones(batchSize, 1));
                for (int k = 0; k < sampleCount; k++) {
                    INDArray input = fakeSet2.get(k).getFeatures();
                    samples[k] = gan.activateSelectedLayers(0, gen.getLayers().length - 1, input);
                }
                visualize(samples);
            }

            // Copy the GAN's generator section back into the stand-alone generator.
            updateGen(gen, gan);

            if (i % 1000 == 0) {
                gen.save(new File(genModel), true);
                dis.save(new File(disModel), true);
                gan.save(new File(ganModel), true);
            }
        }
    }

    /**
     * Seeds {@code gen} and {@code dis} with the parameters of the freshly
     * initialized {@code gan}, so all three networks start from identical weights.
     * The first genLayerCount layers of the GAN belong to the generator, the rest
     * to the discriminator.
     */
    private static void copyParams(MultiLayerNetwork gen, MultiLayerNetwork dis, MultiLayerNetwork gan) {
        int genLayerCount = gen.getLayers().length;
        int totalLayerCount = gan.getLayers().length;
        for (int i = 0; i < genLayerCount; i++) {
            gen.getLayer(i).setParams(gan.getLayer(i).params());
        }
        for (int i = genLayerCount; i < totalLayerCount; i++) {
            dis.getLayer(i - genLayerCount).setParams(gan.getLayer(i).params());
        }
    }

    /**
     * Copies the generator section of the GAN (layers 0..genLayerCount-1) back
     * into the stand-alone generator network.
     */
    private static void updateGen(MultiLayerNetwork gen, MultiLayerNetwork gan) {
        int genLayerCount = gen.getLayers().length;
        for (int layerIdx = 0; layerIdx < genLayerCount; layerIdx++) {
            gen.getLayer(layerIdx).setParams(gan.getLayer(layerIdx).params());
        }
    }

    /**
     * Pushes the stand-alone discriminator's current weights into the frozen
     * discriminator tail of the GAN (layers genLayerCount..end).
     */
    private static void updateGan(MultiLayerNetwork gen, MultiLayerNetwork dis, MultiLayerNetwork gan) {
        int offset = gen.getLayers().length;
        int disLayerCount = gan.getLayers().length - offset;
        for (int i = 0; i < disLayerCount; i++) {
            gan.getLayer(offset + i).setParams(dis.getLayer(i).params());
        }
    }

    /**
     * Renders generated samples in a lazily created Swing window, replacing
     * whatever grid was displayed previously.
     */
    private static void visualize(INDArray[] samples) {
        if (frame == null) {
            panel = new JPanel();
            panel.setLayout(new GridLayout(samples.length / 3, 1, 8, 8));

            frame = new JFrame();
            frame.setTitle("Viz");
            frame.setDefaultCloseOperation(WindowConstants.DISPOSE_ON_CLOSE);
            frame.setLayout(new BorderLayout());
            frame.add(panel, BorderLayout.CENTER);
            frame.setVisible(true);
        }

        panel.removeAll();
        for (int i = 0; i < samples.length; i++) {
            panel.add(getImage(samples[i]));
        }

        frame.revalidate();
        frame.pack();
    }

    /**
     * Converts one generated tensor to an image and wraps it in a JLabel,
     * scaled up 8x (to 224x224) for display.
     */
    private static JLabel getImage(INDArray tensor) {
        ImageIcon original = new ImageIcon(imageFromINDArray(tensor));
        int displaySide = 8 * 28;
        Image enlarged = original.getImage().getScaledInstance(displaySide, displaySide, Image.SCALE_REPLICATE);
        return new JLabel(new ImageIcon(enlarged));
    }
    /**
     * Converts a generated tensor (values roughly in [-1, 1], from the tanh-style
     * rescaling in main) into an RGB image, mapping each channel through
     * (v + 1) * 127.5 and clamping to [0, 255].
     *
     * <p>NOTE(review): the hard-coded reshape(3, 28, 28) does not match the
     * 64x64x3 geometry declared by the class fields (height/width/channels) —
     * confirm the actual generator output size before relying on this.
     * Channel order is assumed BGR (index 2 = red, 0 = blue) — verify against
     * the image loader's convention.
     */
    private static BufferedImage imageFromINDArray(INDArray array) {
        array = array.reshape(3,28,28);
        long[] shape = array.shape();
        long height = shape[1];
        long width = shape[2];
        BufferedImage image = new BufferedImage((int)width, (int)height, BufferedImage.TYPE_INT_RGB);
        for (int x = 0; x < width; x++) {
            for (int y = 0; y < height; y++) {
                double red1 = array.getDouble(2, y, x);
                double green1 = array.getDouble( 1,y, x);
                double blue1 = array.getDouble( 0,y, x);
                // Map [-1, 1] -> [0, 255] and clamp out-of-bounds pixel values.
                int red = Math.min((int) ( (red1 + 1)* 127.5), 255);
                int green = Math.min((int)((green1 +1)* 127.5), 255);
                int blue = Math.min((int)((blue1 + 1)*127.5), 255);
                red = Math.max(red, 0);
                green = Math.max(green, 0);
                blue = Math.max(blue, 0);
                image.setRGB(x, y, new Color(red, green, blue).getRGB());
            }
        }
        return image;
    }


    /**
     * Displays loaded source-data tensors (debug aid) in the shared lazily
     * created Swing window, replacing the previous grid.
     *
     * @param samples tensors to render, one label per tensor
     */
    private static void visualizeSrc(INDArray[] samples) {
        if (frame == null) {
            panel = new JPanel();
            panel.setLayout(new GridLayout(samples.length / 3, 1, 8, 8));

            frame = new JFrame();
            frame.setTitle("Viz");
            frame.setDefaultCloseOperation(WindowConstants.DISPOSE_ON_CLOSE);
            frame.setLayout(new BorderLayout());
            frame.add(panel, BorderLayout.CENTER);
            frame.setVisible(true);
        }

        panel.removeAll();
        for (int i = 0; i < samples.length; i++) {
            panel.add(getImageSrc(samples[i]));
        }

        frame.revalidate();
        frame.pack();
    }
    /**
     * Converts one raw source tensor to an image and wraps it in a JLabel,
     * scaled up 8x (to 224x224) for display.
     */
    private static JLabel getImageSrc(INDArray tensor) {
        ImageIcon original = new ImageIcon(imageFromINDArraySrc(tensor));
        int displaySide = 8 * 28;
        Image enlarged = original.getImage().getScaledInstance(displaySide, displaySide, Image.SCALE_REPLICATE);
        return new JLabel(new ImageIcon(enlarged));
    }
    /**
     * Converts a raw (un-normalized) CHW tensor into an RGB image. Channel values
     * are taken as-is in [0, 255] and clamped at both ends.
     *
     * <p>Channel order is assumed BGR (index 2 = red, 0 = blue) — TODO(review):
     * confirm against the image loader's convention.
     *
     * <p>BUG FIX: the original printed every pixel's channel values to stdout
     * (height * width lines per image), a debugging leftover that flooded the
     * console and dominated runtime; the println has been removed.
     */
    private static BufferedImage imageFromINDArraySrc(INDArray array) {
        long[] shape = array.shape();
        long height = shape[1];
        long width = shape[2];
        BufferedImage image = new BufferedImage((int) width, (int) height, BufferedImage.TYPE_INT_RGB);
        for (int x = 0; x < width; x++) {
            for (int y = 0; y < height; y++) {
                int red = clampToByte(array.getDouble(2, y, x));
                int green = clampToByte(array.getDouble(1, y, x));
                int blue = clampToByte(array.getDouble(0, y, x));
                image.setRGB(x, y, new Color(red, green, blue).getRGB());
            }
        }
        return image;
    }

    /** Clamps a channel value into the valid [0, 255] byte range. */
    private static int clampToByte(double value) {
        return Math.max(0, Math.min((int) value, 255));
    }

}
